"""
Phase 5: CBI and Inflation Targeting as Regime Moderators
Do central bank independence and inflation targeting regimes moderate the
demographic channel into interest rates and inflation?

Outputs:
  - Table 7: CBI interactions (phase5_table7_cbi_interactions.md)
  - Table 8: IT interactions (phase5_table8_it_interactions.md)
  - Table 9: Combined regimes (phase5_table9_combined_regimes.md)
"""

import sys
from pathlib import Path
import numpy as np
import pandas as pd

# ── paths ────────────────────────────────────────────────────────────────────
# Hard-coded project root; every input and output path below is derived from it.
PROJECT = Path("/mnt/c/demographics_capital_flows")
# Put the multilateral source directory at the front of sys.path so that the
# `from model import PanelGLS` below can be resolved (presumably model.py lives
# in that directory — verify if the layout changes).
sys.path.insert(0, str(PROJECT / "multilateral" / "src"))
from model import PanelGLS

# CSV panel that main() loads with pandas.
DATA_PATH = PROJECT / "monetary" / "data" / "processed" / "monetary_panel.csv"
# Directory the Phase 5 markdown tables are written to.
TABLE_DIR = PROJECT / "monetary" / "output" / "tables"
# NOTE: executed at import time, so merely importing this module creates the
# output directory (and any missing parents) if it does not exist yet.
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# The 38 ISO3 codes main() uses to select the OECD subsample from the panel's
# "iso3" column.
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]


# ── helpers ──────────────────────────────────────────────────────────────────
def run_model(panel, dep_var, rhs_vars):
    """Fit PanelGLS of ``dep_var`` on ``rhs_vars`` using complete cases only.

    Rows with a missing value in the dependent variable, any regressor,
    ``iso3`` or ``year`` are dropped first. If fewer than 50 rows remain the
    model is skipped; if fitting raises, the error is printed. In both cases
    ``None`` is returned so callers can leave the table column blank.

    Returns:
        dict with per-regressor ``coefs``/``se``/``pvals`` mappings plus
        ``r2``, ``nobs`` and ``ncountries`` — or ``None`` on skip/failure.
    """
    sample = panel[[dep_var] + rhs_vars + ["iso3", "year"]].dropna()
    n_rows = len(sample)
    if n_rows < 50:
        print(f"  SKIP {dep_var}: only {n_rows} obs after dropna")
        return None

    def by_var(values):
        # Pair each estimate with the regressor name in the same position.
        return dict(zip(rhs_vars, values))

    try:
        target = sample[dep_var].values
        design = sample[rhs_vars].values
        gls = PanelGLS()
        gls.fit(target, design, sample["iso3"].values, sample["year"].values)
        summary = {
            "coefs": by_var(gls.beta),
            "se": by_var(gls.se),
            "pvals": by_var(gls.pvalues),
            "r2": gls.r_squared,
            "nobs": gls.n_obs,
            "ncountries": gls.n_countries,
        }
    except Exception as e:
        # Best-effort: one failed specification must not abort the whole run.
        print(f"  ERROR {dep_var}: {e}")
        return None
    return summary


def star(p):
    """Map a p-value to its significance marker.

    Returns "***" below 0.01, "**" below 0.05, "*" below 0.10 and an empty
    string otherwise (NaN fails every comparison, so it also yields "").
    """
    # Thresholds are checked from strictest to loosest; first hit wins.
    for cutoff, marker in ((0.01, "***"), (0.05, "**"), (0.10, "*")):
        if p < cutoff:
            return marker
    return ""


def fmt_coef(res, var):
    """Render one table cell as 'coef*** (se)' with three decimals.

    Gives an empty string when the model result is missing (``None``) or
    holds no coefficient for ``var``.
    """
    if res is None:
        return ""
    if var not in res["coefs"]:
        return ""
    estimate, std_err = res["coefs"][var], res["se"][var]
    marker = star(res["pvals"][var])
    return f"{estimate:.3f}{marker} ({std_err:.3f})"


def build_markdown_table(title, row_vars, col_labels, results, footer_rows):
    """Assemble a pipe-delimited markdown regression table.

    Layout: a ``# title`` heading followed by a blank line, a header row, a
    divider, one row per entry of ``row_vars`` (cells produced by fmt_coef for
    each element of ``results``), a second divider, then one row per
    ``(label, values)`` pair in ``footer_rows``. The returned string has no
    trailing newline.
    """
    def row(first, cells):
        return f"| {first} | " + " | ".join(cells) + " |"

    # One "---|" per data column plus one for the leading Variable column.
    divider = "|" + "---|" * (len(col_labels) + 1)

    out = [f"# {title}\n", row("Variable", col_labels), divider]
    out.extend(row(var, [fmt_coef(r, var) for r in results]) for var in row_vars)
    out.append(divider)
    out.extend(row(label, vals) for label, vals in footer_rows)
    return "\n".join(out)


def construct_interactions(panel):
    """Return a copy of ``panel`` with Z x CBI and Z x IT products added.

    For each of Z_1..Z_3 the columns ``<Z>_x_cbi`` (Z times ``cbi_index``) and
    ``<Z>_x_it`` (Z times ``it_adopter``) are created unless a column of that
    name already exists, in which case it is left untouched. The input frame
    is not modified.
    """
    out = panel.copy()
    # (column suffix, moderator column); CBI terms are appended before IT terms.
    for suffix, moderator in (("cbi", "cbi_index"), ("it", "it_adopter")):
        for z in ("Z_1", "Z_2", "Z_3"):
            name = f"{z}_x_{suffix}"
            if name in out.columns:
                continue
            out[name] = out[z] * out[moderator]
    return out


# ── main ─────────────────────────────────────────────────────────────────────
def _footer_rows(results):
    """Build the R² / N obs / N countries footer rows; blank for failed fits."""
    return [
        ("R²", [f"{r['r2']:.3f}" if r else "" for r in results]),
        ("N obs", [f"{r['nobs']:,}" if r else "" for r in results]),
        ("N countries", [f"{r['ncountries']}" if r else "" for r in results]),
    ]


def _emit_table(markdown, filename):
    """Print a rendered table and save it under TABLE_DIR."""
    print("\n" + markdown)
    path = TABLE_DIR / filename
    # The tables contain non-ASCII text ("R²"); pin the encoding so the file
    # does not depend on the locale's default.
    path.write_text(markdown, encoding="utf-8")
    print(f"\nSaved: {path}")


def _same_sample(res_full, res_oecd):
    """True when both fits succeeded and report the same, non-zero nobs."""
    n_full = res_full["nobs"] if res_full else 0
    n_oecd = res_oecd["nobs"] if res_oecd else 0
    return n_full == n_oecd and n_full > 0


def _print_interactions(res, ivars, prefix=""):
    """Print coefficient and p-value for each interaction term in ``ivars``."""
    for ivar in ivars:
        c = res["coefs"].get(ivar, np.nan)
        p = res["pvals"].get(ivar, np.nan)
        print(f"  {prefix}{ivar} = {c:.3f} (p={p:.3f})")


def main():
    """Run the Phase 5 regressions and write Tables 7-9 as markdown.

    Steps:
      1. Load the panel from DATA_PATH and add the Z x CBI / Z x IT terms.
      2. Table 7: CBI interactions for real_bond_10y and inflation, on the
         full and OECD samples (collapsed to one column when both fits use
         the same number of observations).
      3. Table 8: IT interactions for inflation on the full, pre-GFC
         (year <= 2007) and post-GFC (year >= 2008) samples.
      4. Table 9: combined CBI + IT model for inflation.
      5. Print a short summary of the Z_1 interaction estimates.

    Each table is printed to stdout and saved under TABLE_DIR. A model that is
    skipped or fails (run_model returns None) yields blank cells.
    """
    print("=" * 70)
    print("PHASE 5: CBI AND INFLATION TARGETING AS REGIME MODERATORS")
    print("=" * 70)

    panel = pd.read_csv(DATA_PATH)
    print(f"Panel: {len(panel):,} obs, {panel['iso3'].nunique()} countries")

    panel = construct_interactions(panel)

    oecd = panel[panel["iso3"].isin(OECD_38)].copy()
    pre_gfc = panel[panel["year"] <= 2007].copy()
    post_gfc = panel[panel["year"] >= 2008].copy()

    print(f"OECD:     {len(oecd):,} obs, {oecd['iso3'].nunique()} countries")
    print(f"Pre-GFC:  {len(pre_gfc):,} obs, {pre_gfc['iso3'].nunique()} countries")
    print(f"Post-GFC: {len(post_gfc):,} obs, {post_gfc['iso3'].nunique()} countries")

    z_vars = ["Z_1", "Z_2", "Z_3"]
    controls_rate = ["rgdp_growth", "inflation", "fiscal_bal_gdp", "kaopen", "nfa_gdp_lag"]
    # Inflation is the dependent variable here, so it is not among the controls.
    controls_infl = ["rgdp_growth", "output_gap", "fiscal_bal_gdp", "kaopen", "nfa_gdp_lag"]

    # ── TABLE 7: CBI Interactions ────────────────────────────────────────
    print("\n── Table 7: CBI Interactions ──")

    cbi_interactions = ["Z_1_x_cbi", "Z_2_x_cbi", "Z_3_x_cbi"]
    rhs_cbi_rate = z_vars + ["cbi_index"] + cbi_interactions + controls_rate
    rhs_cbi_infl = z_vars + ["cbi_index"] + cbi_interactions + controls_infl

    # Union of both specifications' regressors, first-seen order, no repeats.
    row_vars_t7 = list(dict.fromkeys(rhs_cbi_rate + controls_infl))

    results_t7 = []
    col_labels_t7 = []
    for dv, rhs in [("real_bond_10y", rhs_cbi_rate), ("inflation", rhs_cbi_infl)]:
        res_full = run_model(panel, dv, rhs)
        res_oecd = run_model(oecd, dv, rhs)
        if _same_sample(res_full, res_oecd):
            # Full and OECD fits use the same N: report a single column.
            lbl = "10y rate (OECD*)" if "bond" in dv else "Inflation (OECD*)"
            results_t7.append(res_full)
            col_labels_t7.append(lbl)
            print(f"  {lbl}: N={res_full['nobs']} (effectively OECD-only)")
        else:
            results_t7 += [res_full, res_oecd]
            col_labels_t7 += [f"{dv} (Full)", f"{dv} (OECD)"]
        if res_full:
            _print_interactions(res_full, cbi_interactions, prefix=f"{dv}: ")

    md_t7 = build_markdown_table(
        "Table 7: CBI Interactions with Demographics",
        row_vars_t7, col_labels_t7, results_t7, _footer_rows(results_t7),
    )
    _emit_table(md_t7, "phase5_table7_cbi_interactions.md")

    # ── TABLE 8: IT Interactions ─────────────────────────────────────────
    print("\n── Table 8: IT Interactions ──")

    it_interactions = ["Z_1_x_it", "Z_2_x_it", "Z_3_x_it"]
    # Same regressor list serves as the model RHS and as the table rows.
    rhs_it = z_vars + ["it_adopter"] + it_interactions + controls_infl

    t8_specs = [
        ("Full", panel),
        ("Pre-GFC", pre_gfc),
        ("Post-GFC", post_gfc),
    ]

    results_t8 = []
    col_labels_t8 = []
    for label, samp in t8_specs:
        res = run_model(samp, "inflation", rhs_it)
        results_t8.append(res)
        col_labels_t8.append(label)
        if res:
            _print_interactions(res, it_interactions, prefix=f"{label}: ")
        else:
            print(f"  {label}: no result")

    md_t8 = build_markdown_table(
        "Table 8: Inflation Targeting Interactions with Demographics",
        rhs_it, col_labels_t8, results_t8, _footer_rows(results_t8),
    )
    _emit_table(md_t8, "phase5_table8_it_interactions.md")

    # ── TABLE 9: Combined Regime Model ───────────────────────────────────
    print("\n── Table 9: Combined Regime Model ──")

    # NOTE: mi_index is identical to cbi_index in the data (correlation=1.0,
    # likely a data assembly error). Removed mi_index to avoid collinearity.
    combined_interactions = ["Z_1_x_cbi", "Z_1_x_it"]
    rhs_combined = (z_vars + ["cbi_index", "it_adopter"]
                    + combined_interactions + controls_infl)

    res_full = run_model(panel, "inflation", rhs_combined)
    res_oecd = run_model(oecd, "inflation", rhs_combined)
    if _same_sample(res_full, res_oecd):
        results_t9 = [res_full]
        col_labels_t9 = ["OECD*"]
        print(f"  N={res_full['nobs']} (effectively OECD-only)")
    else:
        results_t9 = [res_full, res_oecd]
        col_labels_t9 = ["Full", "OECD"]
    if res_full:
        _print_interactions(res_full, combined_interactions)

    md_t9 = build_markdown_table(
        "Table 9: Combined Regime Model (DV: Inflation)",
        rhs_combined, col_labels_t9, results_t9, _footer_rows(results_t9),
    )
    _emit_table(md_t9, "phase5_table9_combined_regimes.md")

    # ── Summary ──────────────────────────────────────────────────────────
    print("\n── Key Findings ──")

    # CBI moderating effect — first Table 7 column whose label mentions
    # inflation ("nflation" matches both "Inflation (OECD*)" and
    # "inflation (Full)").
    res_cbi_full = next((r for r, l in zip(results_t7, col_labels_t7)
                         if "nflation" in l), None)
    if res_cbi_full and "Z_1_x_cbi" in res_cbi_full["coefs"]:
        c = res_cbi_full["coefs"]["Z_1_x_cbi"]
        p = res_cbi_full["pvals"]["Z_1_x_cbi"]
        direction = "dampens" if c < 0 else "amplifies"
        sig = "significant" if p < 0.10 else "NOT significant"
        print(f"  CBI moderation (inflation, full): Z_1_x_cbi = {c:.3f} (p={p:.3f})")
        print(f"    -> Higher CBI {direction} demographic effect on inflation ({sig})")

    # IT moderating effect; results_t8 follows t8_specs order.
    res_it_full, res_it_pre, res_it_post = results_t8
    if res_it_full and "Z_1_x_it" in res_it_full["coefs"]:
        c = res_it_full["coefs"]["Z_1_x_it"]
        p = res_it_full["pvals"]["Z_1_x_it"]
        direction = "dampens" if c < 0 else "amplifies"
        sig = "significant" if p < 0.10 else "NOT significant"
        print(f"  IT moderation (inflation, full): Z_1_x_it = {c:.3f} (p={p:.3f})")
        print(f"    -> Inflation targeting {direction} demographic effect ({sig})")

    # Pre vs post GFC comparison
    if res_it_pre and res_it_post:
        pre_c = res_it_pre["coefs"].get("Z_1_x_it", np.nan)
        pre_p = res_it_pre["pvals"].get("Z_1_x_it", np.nan)
        post_c = res_it_post["coefs"].get("Z_1_x_it", np.nan)
        post_p = res_it_post["pvals"].get("Z_1_x_it", np.nan)
        print(f"\n  IT moderation pre-GFC:  Z_1_x_it = {pre_c:.3f} (p={pre_p:.3f})")
        print(f"  IT moderation post-GFC: Z_1_x_it = {post_c:.3f} (p={post_p:.3f})")

    print("\nPhase 5 complete.")


if __name__ == "__main__":
    main()
